{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8b99e4ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "notebook_dir = os.getcwd()\n",
    "project_root_path = os.path.dirname(notebook_dir)\n",
    "sys.path.insert(0, project_root_path)\n",
    "\n",
    "from src.models import ModelXtoCResNet\n",
    "from src.preprocessing.Derm7pt import preprocessing_Derm7pt\n",
    "from src.utils import *\n",
    "from src.training import run_epoch_x_to_c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9333beb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of label columns: 5\n",
      "Found 2022 instances.\n",
      "Created matrix of shape: (2022, 5)\n",
      "Total number of concept columns: 19\n",
      "Found 2013 images.\n",
      "Processing in 32 batches of size 64 (for progress reporting)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|██████████| 32/32 [00:13<00:00,  2.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished processing.\n",
      "Successfully transformed: 2013 images.\n",
      "Labels shape: (2013, 5)\n",
      "Concepts shape: (2013, 19)\n",
      "Image tensors length: 2013\n",
      "Dataset initialized with 826 pre-sorted items.\n",
      "Dataset initialized with 406 pre-sorted items.\n",
      "Dataset initialized with 790 pre-sorted items.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "concept_labels, train_loader, val_loader, test_loader = preprocessing_Derm7pt(training=True, class_concepts=True, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75f95a35",
   "metadata": {},
   "source": [
    "# Training Implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e0e00db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils import find_class_imbalance\n",
    "from src.config import DERM7PT_CONFIG as config_dict\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "72754a34",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_TRIMMED_CONCEPTS = config_dict['N_TRIMMED_CONCEPTS']\n",
    "N_CLASSES = config_dict['N_CLASSES']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0858c66",
   "metadata": {},
   "source": [
    "**Find device to run model on (CPU or GPU).**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d7b1efe4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: mps\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available()\n",
    "                    else \"mps\" if torch.backends.mps.is_available()\n",
    "                    else \"cpu\")\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdc74960",
   "metadata": {},
   "source": [
    "**Instantiate the model.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3793dfc2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/pb/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/Users/pb/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.\n",
      "  warnings.warn(msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model Instantiated (X -> C)\n"
     ]
    }
   ],
   "source": [
    "model = ModelXtoCResNet(pretrained=True,\n",
    "                freeze=True,\n",
    "                n_concepts=N_TRIMMED_CONCEPTS,\n",
    "                label_mode=True,\n",
    "                n_classes=N_CLASSES)\n",
    "\n",
    "model = model.to(device)\n",
    "print(\"Model Instantiated (X -> C)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24b06149",
   "metadata": {},
   "source": [
    "### Loss\n",
    "We use weighted loss.\n",
    "\n",
    "`BCEWithLogitsLoss()` performs 2 steps:\n",
    "1. $\\sigma(x)$\n",
    "    - Applies the sigmoid function to the logits to get probabilities.\n",
    "2. $\\text{BCE}(\\sigma(x), y) = y \\cdot \\text{log}(\\sigma(x)) + (1-y) \\cdot (1-\\text{log}(\\sigma(x)))$\n",
    "    - Compute binary cross-entropy between output probabilities ($\\sigma(x)$) and ground truths ($y$)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "22440cca",
   "metadata": {},
   "outputs": [],
   "source": [
    "attr_criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5b0878f",
   "metadata": {},
   "source": [
    "### Optimiser\n",
    "Use same settings as used in CBM repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "008ebdc2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimizer and Scheduler Ready\n"
     ]
    }
   ],
   "source": [
    "lr = 0.01\n",
    "weight_decay = 0.00004 # same as lambda in L2-regularisation\n",
    "\n",
    "optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),\n",
    "                    lr=lr,\n",
    "                    momentum=0.9,\n",
    "                    weight_decay=weight_decay)\n",
    "\n",
    "# scheduler_step = n -> decrease the LR every n epochs\n",
    "scheduler_step = 1000\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=0.1)\n",
    "\n",
    "print(\"Optimizer and Scheduler Ready\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "144d34f8",
   "metadata": {},
   "source": [
    "### Training and Validation Loops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0eb2c7ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 2\n",
    "log_interval = 50\n",
    "\n",
    "best_val_acc = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1ae2e2ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Epoch 1/2 ---\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                  \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 Train Summary | Loss: 1.3285 | Acc: 0.540\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                  \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 Val Summary   | Loss: 1.5828 | Acc: 0.522\n",
      "Validation accuracy improved (0.000 -> 0.522). Saving model...\n",
      "Current LR: 0.01\n",
      "--- Epoch 2/2 ---\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                  \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2 Train Summary | Loss: 1.0361 | Acc: 0.648\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2 Val Summary   | Loss: 1.4273 | Acc: 0.537\n",
      "Validation accuracy improved (0.522 -> 0.537). Saving model...\n",
      "Current LR: 0.01\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "from src.training import run_epoch_x_to_y\n",
    "\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    print(f\"--- Epoch {epoch+1}/{epochs} ---\")\n",
    "\n",
    "    # Train\n",
    "    train_loss, train_acc = run_epoch_x_to_y(model, train_loader, attr_criterion, optimizer, is_training=True, device=device, verbose=True)\n",
    "    print(f'Epoch {epoch+1} Train Summary | Loss: {train_loss:.4f} | Acc: {train_acc:.3f}')\n",
    "\n",
    "    # Validate\n",
    "    if val_loader:\n",
    "        with torch.no_grad():\n",
    "            val_loss, val_acc = run_epoch_x_to_y(model, val_loader, attr_criterion, optimizer, device=device, verbose=True)\n",
    "\n",
    "        print(f'Epoch {epoch+1} Val Summary   | Loss: {val_loss:.4f} | Acc: {val_acc:.3f}')\n",
    "\n",
    "        # Save best model based on validation accuracy\n",
    "        if val_acc > best_val_acc:\n",
    "            print(f\"Validation accuracy improved ({best_val_acc:.3f} -> {val_acc:.3f}). Saving model...\")\n",
    "            best_val_acc = val_acc\n",
    "            # torch.save(model, 'x_to_c_best_model.pth')\n",
    "            # print(\"Model saved to x_to_c_best_model.pth\")\n",
    "\n",
    "    # Scheduler step\n",
    "    scheduler.step()\n",
    "    print(f\"Current LR: {optimizer.param_groups[0]['lr']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "19e51089",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                    "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best Model Summary   | Loss: 1.2001 | Acc: 0.605\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "if test_loader:\n",
    "    with torch.no_grad():\n",
    "\n",
    "        test_loss, test_acc = run_epoch_x_to_y(model, test_loader, attr_criterion, optimizer, device=device, verbose=True)\n",
    "\n",
    "# print(f\"Shuffled labels shape: {shuffled_img_labels.shape}\")\n",
    "print(f'Best Model Summary   | Loss: {test_loss:.4f} | Acc: {test_acc:.3f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d30f3352",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dab91783",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d652f20",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
